@ysiraichi (Collaborator)

This PR refactors the clamp operation implementation, improving its error message and making it return a status-type value.

Key Changes:

  • Make tensor_methods::clamp return StatusOr<XLATensorPtr>
  • Improve error handling
    • Inline the GetMinMaxValues() function
    • Move the check into a new CheckClampMinOrMax() function (see the sketch below)
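
A minimal sketch of how these pieces might fit together is shown below. This is not the actual torch_xla code: XLATensorPtr is replaced by a simplified stand-in, and plain absl status types are used in place of the project's own status helpers in torch_xla/csrc/tensor_methods.cpp.

// Sketch only. CheckClampMinOrMax and the clamp signature mirror the PR
// description; the types around them are simplified stand-ins.
#include <memory>
#include <optional>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

struct XLATensor {};
using XLATensorPtr = std::shared_ptr<XLATensor>;

// The new check: fail early, with a user-facing message, when both
// bounds are missing.
absl::Status CheckClampMinOrMax(const std::optional<double>& min,
                                const std::optional<double>& max) {
  if (!min.has_value() && !max.has_value()) {
    return absl::InvalidArgumentError(
        "clamp(): expected at least one of `min` or `max` arguments "
        "to be specified.");
  }
  return absl::OkStatus();
}

// Instead of crashing via a fatal check, clamp now propagates the error
// to its caller through the StatusOr return value.
absl::StatusOr<XLATensorPtr> clamp(const XLATensorPtr& input,
                                   const std::optional<double>& min,
                                   const std::optional<double>& max) {
  absl::Status status = CheckClampMinOrMax(min, max);
  if (!status.ok()) {
    return status;
  }
  // ... build and return the clamped tensor (elided) ...
  return input;
}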

Example

import torch
import torch_xla.core.xla_model as xm

a = torch.rand(10, device=xm.xla_device())

torch.ops.aten.clamp.default(a)

Before:

Traceback (most recent call last):
  File "examples/clamp.py", line 6, in <module>
    torch.ops.aten.clamp.default(a)
  File "torch/_ops.py", line 841, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Check failed: min || max: At least one of 'min' or 'max' must not be None (at torch_xla/csrc/tensor_methods.cpp:183)

Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:26 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/clamp.py", line 6, in <module>
    torch.ops.aten.clamp.default(a)
  File "torch/_ops.py", line 840, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: clamp(): expected at least one of `min` or `max` arguments to be specified.

Status Propagation Trace:
    From: CheckClampMinOrMax at torch_xla/csrc/tensor_methods.cpp:496 (error: clamp(): expected at least one of `min` or `max` arguments to be specified.)
    From: clamp at torch_xla/csrc/tensor_methods.cpp:1357
    From: clamp at torch_xla/csrc/aten_xla_type.cpp:1377

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):
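
At the outermost layer (clamp in aten_xla_type.cpp in the trace above), the status has to be unwrapped before control returns to Python. The sketch below continues the simplified example from the Key Changes section; UnwrapOrThrow is a hypothetical stand-in for the role played by ThrowStatusError in torch_xla/csrc/status.cpp.

// Continues the sketch above. On error, raise the status message as a
// C++ exception, which PyTorch surfaces to Python as the RuntimeError
// shown in the traceback.
#include <stdexcept>
#include <string>
#include <utility>

XLATensorPtr UnwrapOrThrow(absl::StatusOr<XLATensorPtr> result) {
  if (!result.ok()) {
    throw std::runtime_error(std::string(result.status().message()));
  }
  return std::move(result).value();
}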

@ysiraichi force-pushed the ysiraichi/better-error-clamp branch from 214b375 to 9c7c764 on October 14, 2025 at 19:39.
@ysiraichi merged commit 87e631a into master on October 15, 2025.
20 checks passed